{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# faiss基础模块\n",
    "faiss中的索引基于几个基础算法构建，faiss为这些算法提供了高效的实现。它们分别是k-means聚类、PCA降维、PQ编码/解码。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## k-means聚类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[3.046999  3.0121088 3.0124333 ... 3.0203993 3.011947  2.9342847]\n",
      " [2.7589808 3.0725713 2.9360871 ... 3.0773525 2.902585  2.995511 ]\n",
      " [3.1167192 2.9537685 2.9987445 ... 3.0199993 2.9278672 3.050025 ]\n",
      " ...\n",
      " [2.9502757 3.0440164 2.9121387 ... 2.9652288 3.2078865 3.009649 ]\n",
      " [2.9459333 3.0297534 2.9002755 ... 2.9255435 2.8951385 2.9468067]\n",
      " [2.9947238 3.1082706 2.9418213 ... 3.0144033 3.046606  2.9184723]]\n"
     ]
    }
   ],
   "source": [
    "# Make the locally built faiss Python bindings importable, then load faiss.\n",
    "# NOTE(review): the path below is machine-specific; adjust it, or drop the\n",
    "# sys.path line entirely if faiss is installed as a regular package.\n",
    "import sys\n",
    "sys.path.append('/home/maliqi/faiss/python/')\n",
    "import faiss\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Synthetic dataset: n_data vectors of dimension d, every component drawn\n",
    "# from a normal distribution with mean mu and standard deviation sigma.\n",
    "d = 512          # vector dimensionality\n",
    "n_data = 2000    # number of vectors\n",
    "mu = 3\n",
    "sigma = 0.1\n",
    "np.random.seed(0)  # fixed seed so the generated data is reproducible\n",
    "# Draw the whole (n_data, d) matrix in one call instead of appending row by\n",
    "# row, then cast to float32 as the original cell did before handing it to faiss.\n",
    "data = np.random.normal(mu, sigma, (n_data, d)).astype('float32')\n",
    "\n",
    "# k-means clustering\n",
    "# NOTE(review): 1024 centroids for 2000 points leaves about two points per\n",
    "# centroid; faiss may warn about too few training points - confirm this is intended.\n",
    "ncentroids = 1024\n",
    "niter = 20\n",
    "verbose = True\n",
    "# niter/verbose are passed by keyword: newer faiss releases only accept\n",
    "# (d, k) positionally and take every other option as a keyword argument.\n",
    "kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose)\n",
    "kmeans.train(data)\n",
    "\n",
    "# Print the learned cluster centroids.\n",
    "print(kmeans.centroids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4.899538 ]\n",
      " [2.2404225]\n",
      " [3.0874515]\n",
      " [4.472025 ]\n",
      " [2.1018007]]\n",
      "[[ 61]\n",
      " [767]\n",
      " [393]\n",
      " [415]\n",
      " [175]]\n"
     ]
    }
   ],
   "source": [
    "# Find which cluster each of the first five vectors belongs to: for every\n",
    "# query the search returns the distance to its nearest centroid (D) and that\n",
    "# centroid's index (I).\n",
    "# NOTE(review): faiss L2 searches normally report squared distances - confirm.\n",
    "queries = data[:5]\n",
    "D, I = kmeans.index.search(queries, 1)\n",
    "print(D)\n",
    "print(I)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 3.2460938e+00  4.0878906e+00  4.2246094e+00  4.2275391e+00\n",
      "   4.3750000e+00]\n",
      " [-9.7656250e-04  8.4746094e+00  8.5214844e+00  8.7089844e+00\n",
      "   8.7460938e+00]\n",
      " [-1.9531250e-03  8.4355469e+00  8.4599609e+00  8.5195312e+00\n",
      "   8.6240234e+00]\n",
      " ...\n",
      " [ 9.7656250e-04  8.8291016e+00  8.8359375e+00  8.8916016e+00\n",
      "   8.9335938e+00]\n",
      " [ 2.5214844e+00  3.0458984e+00  3.2636719e+00  5.7021484e+00\n",
      "   5.9355469e+00]\n",
      " [ 2.2187500e+00  2.2197266e+00  6.6103516e+00  6.6591797e+00\n",
      "   6.6679688e+00]]\n",
      "[[1083  472  356 1892   34]\n",
      " [1411  414  198  620 1129]\n",
      " [ 140  317 1686   24  402]\n",
      " ...\n",
      " [ 753 1776  331  389  279]\n",
      " [ 432 1096  240  879  329]\n",
      " [ 625 1211  751  106 1318]]\n"
     ]
    }
   ],
   "source": [
    "# For every cluster centroid, look up the dataset vectors closest to it.\n",
    "n_neighbors = 5  # how many nearest vectors to report per centroid\n",
    "index = faiss.IndexFlatL2(d)\n",
    "index.add(data)\n",
    "D, I = index.search(kmeans.centroids, n_neighbors)\n",
    "print(D)  # distances, one row per centroid\n",
    "print(I)  # positions of the matching vectors inside data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PCA降维"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2000, 64)\n"
     ]
    }
   ],
   "source": [
    "# PCA dimensionality reduction: project the dataset down to 64 dimensions.\n",
    "d_in = data.shape[1]  # input dimension, read from the data instead of hard-coding 512\n",
    "d_out = 64            # target dimension after reduction\n",
    "mat = faiss.PCAMatrix(d_in, d_out)\n",
    "mat.train(data)\n",
    "assert mat.is_trained\n",
    "# Apply the trained transform to the dataset.\n",
    "tr = mat.apply_py(data)\n",
    "print(tr.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PQ编码/解码\n",
    "ProductQuantizer对象可以将向量编码为code。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0008765541\n"
     ]
    }
   ],
   "source": [
    "# Product quantizer round trip: train, encode, decode, then measure the loss.\n",
    "d = 512    # vector dimensionality\n",
    "cs = 4     # code size (bytes)\n",
    "nbits = 8  # third ProductQuantizer argument; presumably bits per sub-quantizer - confirm\n",
    "\n",
    "xt = data  # training set\n",
    "x = data   # dataset to encode (could be same as train)\n",
    "\n",
    "pq = faiss.ProductQuantizer(d, cs, nbits)\n",
    "pq.train(xt)\n",
    "\n",
    "codes = pq.compute_codes(x)  # encode\n",
    "x2 = pq.decode(codes)        # decode\n",
    "\n",
    "# Relative reconstruction error: squared difference between the original and\n",
    "# the decoded data, divided by the squared magnitude of the original.\n",
    "squared_error = ((x - x2) ** 2).sum()\n",
    "squared_norm = (x ** 2).sum()\n",
    "avg_relative_error = squared_error / squared_norm\n",
    "print(avg_relative_error)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "标量量化器（scalar quantizer）与之类似。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.7287445e-08\n"
     ]
    }
   ],
   "source": [
    "# Scalar quantizer round trip: train, encode, decode, then measure the loss.\n",
    "d = 512  # vector dimensionality\n",
    "\n",
    "xt = data  # training set\n",
    "x = data   # dataset to encode (could be same as train)\n",
    "\n",
    "# QT_8bit allocates 8 bits per dimension (QT_4bit also works)\n",
    "qtype = faiss.ScalarQuantizer.QT_8bit\n",
    "sq = faiss.ScalarQuantizer(d, qtype)\n",
    "sq.train(xt)\n",
    "\n",
    "codes = sq.compute_codes(x)  # encode\n",
    "x2 = sq.decode(codes)        # decode\n",
    "\n",
    "# Relative reconstruction error: squared difference between the original and\n",
    "# the decoded data, divided by the squared magnitude of the original.\n",
    "squared_error = ((x - x2) ** 2).sum()\n",
    "squared_norm = (x ** 2).sum()\n",
    "avg_relative_error = squared_error / squared_norm\n",
    "print(avg_relative_error)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
